import torchvision

def _pipeline_steps(dataset_name):
    """Build a fresh list of transform steps for *dataset_name*.

    Every supported dataset currently applies the same preprocessing to its
    train and test splits, so a single step list describes both. New
    transform instances are created on every call, so two calls never share
    step objects.

    Returns ``None`` when *dataset_name* is not one of the supported names.
    """
    if dataset_name in ('CIFAR10', 'CIFAR100', 'SVHN'):
        return [torchvision.transforms.ToTensor()]
    if dataset_name == 'CelebA':
        return [
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor(),
        ]
    if dataset_name == 'MNIST':
        # Resize to 32x32 and emit a 3-channel grayscale image (r == g == b)
        # so the output has 3 channels like the colour datasets.
        return [
            torchvision.transforms.Resize((32, 32)),
            torchvision.transforms.Grayscale(num_output_channels=3),
            torchvision.transforms.ToTensor(),
        ]
    if dataset_name == 'STL10':
        # Train-time augmentation (RandomCrop(32, padding=4),
        # RandomHorizontalFlip) is currently disabled for this dataset.
        return [
            torchvision.transforms.ToTensor(),
            # mean = std = 0.5 maps ToTensor's [0, 1] range to [-1, 1].
            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    if dataset_name == 'cars':
        return [
            # Resize the shorter side to 256, keeping the aspect ratio ...
            torchvision.transforms.Resize(256),
            # ... then resize (NOT crop) to exactly 224x224, which does not
            # preserve aspect ratio.
            # NOTE(review): if a centre crop was intended, this step should be
            # torchvision.transforms.CenterCrop(224); changing it alters outputs.
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    return None


def get_transforms(dataset_name):
    """Return the train/test preprocessing pipelines for a dataset.

    Args:
        dataset_name: One of 'CIFAR10', 'CIFAR100', 'SVHN', 'CelebA',
            'MNIST', 'STL10' or 'cars' (case-sensitive).

    Returns:
        A dict with keys ``"train"`` and ``"test"``, each mapping to its own
        ``torchvision.transforms.Compose`` instance. The two pipelines
        currently apply the same steps but never share transform objects.

    Raises:
        NotImplementedError: If *dataset_name* is not supported.
    """
    train_steps = _pipeline_steps(dataset_name)
    if train_steps is None:
        raise NotImplementedError(
            f"No transforms defined for dataset {dataset_name!r}"
        )
    # Build a second, independent step list so train-only changes (e.g.
    # augmentation) can later be applied without affecting the test pipeline.
    test_steps = _pipeline_steps(dataset_name)
    return {
        "train": torchvision.transforms.Compose(train_steps),
        "test": torchvision.transforms.Compose(test_steps),
    }



